# -*- coding: utf-8 -*-
"""
Created on Tue Sep 20 15:42:20 2022

@author: zhiyinan
"""
import numpy as np
from casadi import *
import mpctools as mpc
import matplotlib.pyplot as plt
import time as time
import matplotlib.font_manager as font_manager

# %% Define the system

def sys(x, u, w, dt=0.05, t_step=0.1):
    """Advance the CSTR model by one sampling period.

    The mass and energy balances are integrated with ``round(t_step / dt)``
    fixed sub-steps of size ``dt``. The defaults give two sub-steps of 0.05,
    which is exactly the previous hard-coded behaviour. Time is in minutes,
    matching the [L/min] and [1/min] constants below.

    The name shadows the stdlib ``sys`` module. It is kept because the rest
    of the script calls ``sys(x, u, w)``.

    Parameters
    ----------
    x : indexable
        State. ``x[0]`` is the concentration C_A [mol/L] and ``x[1]`` the
        reactor temperature T [K]. May be numeric or CasADi symbolic.
    u : scalar-like
        Coolant temperature T_c [K].
    w : indexable
        Additive disturbances. ``w[0]`` acts on the feed concentration
        (nominal 1 mol/L) and ``w[1]`` on the feed temperature (nominal 350 K).
    dt : float, optional
        Integration sub-step [min].
    t_step : float, optional
        Time advanced per call [min]. The default 0.1 matches the
        controller's sampling period ``Delta`` used in this script.

    Returns
    -------
    CasADi 2x1 column ``[C_A, T]`` after ``t_step``.

    Raises
    ------
    ValueError
        If ``t_step / dt`` rounds to zero sub-steps.
    """
    q = 100  # feed flow rate [L/min]
    V = 100  # reactor volume [L]
    C_Af = 1 + w[0]  # feed concentration [mol/L]: nominal 1 plus disturbance
    T_f = 350 + w[1]  # feed temperature [K]: nominal 350 plus disturbance
    EdR = 8750  # activation energy over gas constant, E/R [K]
    k_0 = 7.2e10  # pre-exponential factor [1/min]
    del_H = -5e4  # heat of reaction [J/mol]
    UA = 5e4  # heat-transfer coefficient times area [J/min/K]
    C_p = 0.239  # heat capacity [J/g/K]
    rho = 1000  # density [g/L]

    C_A = x[0]
    T = x[1]
    T_c = u

    # round() instead of int(): e.g. 0.3 / 0.1 == 2.9999999999999996, and
    # int() would silently drop a sub-step.
    n_sub = int(round(t_step / dt))
    if n_sub < 1:
        raise ValueError(
            f"t_step={t_step} is too short for sub-step dt={dt}")

    for _ in range(n_sub):
        C_A = C_A + dt * ((q/V)*(C_Af - C_A) - k_0*exp(-EdR/T)*C_A)
        # NOTE(review): this uses the C_A just updated above (a sequential,
        # Gauss-Seidel-like update) rather than textbook explicit Euler.
        # It is kept unchanged to preserve existing results; confirm intent.
        T = T + dt * ((q/V)*(T_f - T)
                      + (-del_H/(rho*C_p))*k_0*exp(-EdR/T)*C_A
                      + (UA/(V*rho*C_p))*(-T)
                      + (UA/(V*rho*C_p))*T_c)
    return vertcat(C_A, T)

# Problem dimensions: Nu inputs (coolant temperature T_c) and Nx states
# (C_A, T), matching how sys() reads u, x[0] and x[1].
Nu = 1
Nx = 2

# MX symbols for a single model evaluation. They are used below to build the
# disturbance-sensitivity Function dxdw, before u and x are rebound to SX
# matrices for the NLP.
u = MX.sym('u',Nu)
x = MX.sym('x',Nx)
w = MX.sym('w',2)

# NOTE(review): this MX `p` only feeds the commented-out integrator below;
# it is rebound to [] before the NLP is assembled.
p = vertcat(u,w)

# Sampling period [min]. As used in this script, sys() advances the state by
# 0.1 min per call, so Delta is expected to equal that value.
Delta = 0.1 
dt_in = 0.1  # step of the commented-out CVODES integrator (otherwise unused)
Nt = 1  # Horizon length [min]; the NLP uses int(Nt/Delta) steps
Nsim = 5  # Simulated closed-loop time [min]; int(Nsim/Delta) steps


# ode = {'x':x, 'p': p, 'ode':sys(x, p[0], p[1:])}
# opts = {'tf': dt_in, 'regularity_check': True}
# F = integrator('F', 'cvodes', ode, opts)



# def sim_sys(x,u,w):
#     for i in range(int(Delta/dt_in)):
#         r = F(x0 = x, p = vertcat(u,w))
#         x = r["xf"]
#     return x

# Constraints
# State box: C_A in [0, 1] mol/L, T in [345, 355] K.
lbx = np.array([0, 345])
ubx = np.array([1, 355])  # upper bounds [C_A mol/L, T K]

# Coolant temperature (input) bounds [K].
lbu = 285
ubu = 315



# ======================Nominal target zone (on x)======================
# Target zone on the reactor temperature T = x[1] [K]. The stage cost obj()
# and the zone-cost evaluation further below both act on x[1].
lbx_zn = 348
ubx_zn = 352
# lbx_zn = 1.5
# ubx_zn = 8.5

# ===========Control Invariant terminal set===========
# Sets are stored in H-representation {x : Ha @ x <= Hb}. The terminal
# constraint below is imposed as Ha_s @ x_N - Hb_s <= 0.
# Original set (used only for plotting in this script)
Ha = DM(np.load("cis_song_toMia/results/hrepABig.npy"))
Hb = DM(np.load("cis_song_toMia/results/hrepBBig.npy"))

# Shrunk set, used as the terminal constraint.
# NOTE(review): the file names say "gamma1"; presumably they must match the
# `gamma` value set further below. Confirm when changing gamma.
Ha_s = DM(np.load("cis_song_toMia/results/hrepAPM0_gamma1.npy"))
Hb_s = DM(np.load("cis_song_toMia/results/hrepBPM0_gamma1.npy"))
# ======================Controller hyperparameters======================


# =============================Disturbance==============================
# Bounds on the additive disturbances w = [w_CAf, w_Tf]. sys() adds them to
# the nominal feed concentration (1 mol/L) and feed temperature (350 K).
lbw = np.array([-0.1, -2])
ubw = np.array([0.1, 2])

# True (plant) disturbance sequence: one uniform sample in [lbw, ubw] per
# sampling period, covering Nsim + Nt minutes.
np.random.seed(5)
# round() rather than int(): a quotient such as 0.3/0.1 evaluates to
# 2.9999999999999996, and int() would silently drop a sample.
n_w = int(round((Nsim + Nt) / Delta))
# rand(n_w, 2) fills the array in row-major order. That is the same draws, in
# the same order, as a per-element loop over (i, 0), (i, 1), so the seeded
# sequence is unchanged.
w_t = np.random.rand(n_w, 2)*(ubw - lbw) + lbw

# Steady-state computation kept for reference (not executed):
# x0 = 8
# nlp = {'x':vertcat(x,u), 'f': x[0], 'g':sys(x,u,np.zeros(2))-x}
# S = nlpsol('S','ipopt',nlp)
# r = S(x0 = np.array([0.5, 350, 300]), lbx = vertcat(np.array([0, 348]), lbu), ubx = vertcat(np.array([1, 352]), ubu), lbg = 0, ubg = 0)
# xs = r['x'][:Nx]
# us = r['x'][Nx:]
# %% Calculate the sensitivity of the state w.r.t. the disturbance
# Jacobian of the one-step map sys(x, u, w) with respect to w. It is built on
# the MX symbols x, u, w defined above, before they are rebound to SX.
dxdw = Function('dxdw', [x,u,w], [jacobian(sys(x, u, w), w)])
#ex = dxdw(xs,us,ubw)@ubw

# Linearised one-step state deviation for the disturbance at its upper bound.
# It is evaluated at the fixed point (C_A, T) = (0.75390625, 352) with the
# maximum coolant temperature u = 315.
# NOTE(review): J @ ubw is a worst-case bound only if the relevant Jacobian
# entries are non-negative; otherwise abs(J) @ max(|lbw|, |ubw|) would be
# safer. The evaluation point also looks hand-picked; confirm both.
# Flattened so that xd_max[1] is a plain element rather than a length-1 row.
xd_max = np.array(dxdw(np.array([0.75390625, 352]), 315, ubw)@ubw).flatten()

gamma = 1
# Shrink margin on T. float() keeps s (and hence lbx_z/ubx_z) scalar.
# Otherwise they are shape-(1,) arrays that np.arange must later coerce to a
# scalar, and that size-1 array conversion is deprecated since NumPy 1.25.
s = gamma*float(xd_max[1])
# Shrunk target zone on T = x[1]
lbx_z = lbx_zn + s
ubx_z = ubx_zn - s

# %% Define the proposed ZMPC controller
def obj(x, xz, Q=10):
    """Zone-tracking stage cost on the reactor temperature.

    Parameters
    ----------
    x : indexable
        State ``[C_A, T]``. Only ``x[1]`` (the temperature) is penalised.
    xz : scalar or CasADi expression
        Zone-tracking target for the temperature. In this script it is an
        NLP decision variable bounded to ``[lbx_zn, ubx_zn]``.
    Q : float, optional
        Quadratic weight (default 10, the previously hard-coded value).

    Returns
    -------
    ``Q * (x[1] - xz)**2``
    """
    # An economic term on x[0] (C_A) was experimented with but is not part of
    # the cost; only the zone-tracking term is returned.
    return Q*(x[1] - xz)**2

# %% Assemble the ZMPC NLP (direct multiple shooting: the states are decision
# variables linked by continuity constraints).
# W : decision variables, W0 : initial guess, W_lb/W_ub : variable bounds,
# f : cost, g : constraint expressions, p : parameters (predicted
# disturbances). All are filled step by step below.
W = [] 
W0 = []
W_lb = []
W_ub = []
f = 0
g = []
p = []
# pn = []
# lbg = []
# ubg = []

# Define the unscaled variables as SX. This rebinds u and x, which were MX
# symbols above; dxdw has already been built from the MX versions.
# u  : inputs, one row per prediction step  (int(Nt/Delta) x Nu)
# x  : states x_0 ... x_N, one row per node (int(Nt/Delta)+1 x Nx)
# z  : disturbances in the prediction       (int(Nt/Delta) x 2)
# xz : zone-tracking targets on T           (int(Nt/Delta) x 1)
u = SX.sym('u',(int(Nt/Delta), Nu))
x = SX.sym('x',(int(Nt/Delta)+1, Nx))
z = SX.sym('z', (int(Nt/Delta), 2))
# x_0 = SX.sym('x_0',Nx)
# uz = SX.sym('uz',(int(Nt/Delta), Nu))
xz = SX.sym('xz',(int(Nt/Delta), 1))


# xz = SX.sym('xz',(int(Nt/Delta)+1, Nx))
# z = SX.sym('z',(Nz,int(Nt/Delta+1)))

# Initial guesses of the input and the state
u_g = 299.413
x_g = np.array([0.9, 345])

# Initial condition of the state. x0 is overwritten with the plant state at
# every closed-loop step of the simulation below.
x0 = np.array([0.9, 345])
xt = DM(x0)  # NOTE(review): xt appears unused in this script

# xt0 = MX.sym('xt0',Nx)
# First entries of W: the initial state x_0. It is fixed to x0 through equal
# lower/upper bounds, which are re-pinned before every solve.
W = vertcat(W,x[0,:])
W_lb = vertcat(W_lb,x0)
W_ub = vertcat(W_ub,x0)
W0 = vertcat(W0,x0)

# W = horzcat(W,up)
# W_lb = vertcat(W_lb,u0)
# W_ub = vertcat(W_ub,u0)
# W0 = vertcat(W0,u0)
# temp = x_0
# u_p = up
# u_p = up
# For each prediction step t, append [u_t, x_{t+1}, xz_t] to W. After the
# loop W = [x_0, u_0, x_1, xz_0, u_1, x_2, xz_1, ...]. The closed-loop
# simulation indexes the solution r['x'] relying on this layout.
for t in range(int(Nt/Delta)): 
    # u = MX.sym('u_' + str(t), Nu)
    # Input u_t with box bounds [lbu, ubu].
    W = horzcat(W,u[t]) 
    W_lb = vertcat(W_lb, lbu) # lbu
    W_ub = vertcat(W_ub, ubu) 
    W0 = vertcat(W0, u_g)
    
    # x = MX.sym('x_' + str(t), Nx)
    # Next state x_{t+1} with the state box bounds [lbx, ubx].
    W = horzcat(W, x[t+1, :])
    W_lb = vertcat(W_lb, lbx)
    W_ub = vertcat(W_ub, ubx) # 15)
    W0 = vertcat(W0,x_g)
    
    # W = horzcat(W,uz[t,:])
    # W_lb = vertcat(W_lb,uz_lb)
    # W_ub = vertcat(W_ub,uz_ub)
    # W0 = vertcat(W0,u_g)
    # xz = MX.sym('xz_' + str(t))
    # Zone target xz_t on T, restricted to the nominal target zone.
    W = horzcat(W,xz[t,:])
    W_lb = vertcat(W_lb,lbx_zn)
    W_ub = vertcat(W_ub,ubx_zn)
    W0 = vertcat(W0,x_g[1])
    
    # z = MX.sym('z_' + str(t), 2)
    # Predicted disturbance z_t enters as an NLP parameter. The solve call
    # in the simulation passes p = 0, i.e. nominal predictions.
    p = vertcat(p,z[t,:])
    # pn = vertcat(pn, DM([1, 350]))
    
    # Continuity (shooting) constraints: sys(x_t, u_t, z_t) - x_{t+1} = 0.
    temp = sys(x[t,:], u[t,:], z[t,:])
    g = vertcat(g, (temp[0] - x[t+1,0]))
    g = vertcat(g, (temp[1] - x[t+1,1]))
    # lbg = vertcat(lbg, 0)
    # ubg = vertcat(ubg, 0)
    # xt = temp 
    # Zone-tracking stage cost on the predicted temperature of x_{t+1}.
    f += obj(x[t+1,:], xz[t,:]) #u[t,:], u_p, xz[t,:]])
    # u_p = u[t,:]
# for t in range(int(Nt/Delta)-1):
#     temp = sys(temp, u[t,:])
#     g = vertcat(g, temp)
#     lbg = vertcat(lbg, -5)
#     ubg = vertcat(ubg, 5)

# Terminal Constraint: x_N must lie in the shrunk control invariant set,
# Ha_s @ x_N <= Hb_s. The upper bound 0 is passed as ubg at solve time.
g = vertcat(g, Ha_s@x[-1,:].T - Hb_s)
# g = vertcat(g, 5-x[-1,:])
# g = vertcat(g, x[-1,:] - 10)
# Lower bounds on g: 0 for the N*Nx continuity rows (with ubg = 0 they become
# equalities) and -inf for the terminal inequalities.
lbg = vertcat(np.zeros(int(Nt/Delta)*Nx), -inf*np.ones(Ha_s.shape[0]))  # , -inf, -inf

# # %%
# u_sim = vertcat(np.ones(40)*299.709, np.ones(60)*299.709)
# x_sim = np.zeros((101,2))
# x_sim[0] = x0
# for i in range(100):
#     x_sim[i+1,:] = np.array(sys(x_sim[i,:],u_sim[i],np.zeros(2))).flatten()
# plt.figure(1)
# plt.plot(x_sim[:,0])

# plt.figure(2)
# plt.plot(x_sim[:,1])
# %% Solve the ZMPC
# W was assembled with horzcat as a single row. Reshape it into a column;
# CasADi reshape is column-major, so the entry order is preserved.
W = W.reshape((-1,1))   
nlp = {'x':W, 'p':p, 'f':f, 'g':g}
# Loose IPOPT tolerances (1e-3), presumably to keep each closed-loop solve short.
opts = {'ipopt.acceptable_tol':1e-3, 'ipopt.tol':1e-3} #'ipopt.hessian_approximation': 'limited-memory','ipopt.acceptable_tol':1e-3, 'ipopt.tol':1e-3} # } # 
S = nlpsol('S','ipopt',nlp, opts)

# ode_sim = mpc.DiscreteSimulator(dyn.ode, Delta, [Nx, Nu], ['x', 'u'])

# %% Closed-loop simulation.
# At each sampling period:
#   1. pin x_0 in W to the current plant state;
#   2. solve the ZMPC NLP with nominal predictions (p = 0);
#   3. apply the first input u_0 to the plant with the true disturbance w_t[i];
#   4. log the state, input and zone target.
n_steps = int(Nsim/Delta)

u_opt = DM.zeros((n_steps, Nu))
# xz has a single entry (a target on T) per prediction step, so one column.
x_z = DM.zeros((n_steps, 1))
x_opt = DM.zeros((n_steps + 1, Nx))

# Offsets into the solution vector W = [x_0, u_0, x_1, xz_0, u_1, ...].
idx_u0 = Nx              # first input u_0
idx_xz0 = 2*Nx + Nu      # zone target xz_0 of the first prediction step

start = time.time()
x_opt[0,:] = x0
for i in range(n_steps):
    W_lb[:Nx] = x0
    W_ub[:Nx] = x0
    W0[:Nx] = x0
    r = S(x0 = W0, lbx = W_lb, ubx = W_ub, lbg = lbg, ubg = 0, p = 0)
    stats = S.stats()
    if not stats['success']:
        # Do not fail silently: the input applied below is IPOPT's last iterate.
        print(f"Warning: ZMPC solve at step {i} ended with status "
              f"'{stats['return_status']}'")
    u_i = r['x'][idx_u0:idx_u0 + Nu]
    u_opt[i,:] = u_i
    # Exactly one entry. A slice 2*Nx+Nu : 3*Nx+Nu would also pick up u_1,
    # the next element of W, and log an input as a zone target.
    x_z[i,:] = r['x'][idx_xz0]
    # Advance the plant one sampling period with the true disturbance.
    x0 = np.array(sys(x0, u_i, w_t[i,:]))
    x_opt[i+1,:] = x0
# %% Closed-loop performance metrics on the temperature T = x[1].
# zone_v[i]: amount by which T exceeds the upper zone bound at step i.
# NOTE(review): only the upper side is measured, unlike zone_c below, which
# penalises both sides. Confirm whether lower-side violation was meant to be
# included. zone_v is not used later in this script.
zone_v = np.zeros(int(Nsim/Delta))
for i in range(int(Nsim/Delta)):
    zone_v[i] = max(0, x_opt[i,1] - ubx_zn)

# zone_c: accumulated quadratic distance of T to the nominal target zone,
# with the same weight Q = 10 as obj().
# NOTE(review): i covers x_opt[0] ... x_opt[N-1]. The fixed initial state is
# therefore included and the final state x_opt[N] is not, whereas the MPC
# stage cost acts on x_{t+1}. Confirm the intended window.
zone_c = 0
for i in range(int(Nsim/Delta)):
    Q = 10
    if x_opt[i,1] <=lbx_zn:
        zone_c += Q*(x_opt[i,1] - lbx_zn)**2
    elif x_opt[i,1] >= ubx_zn:
        zone_c += Q*(x_opt[i,1] - ubx_zn)**2
            
    # c2 = 100000
    
# obj_tot = 0
# for i in range(int(Nsim/Delta)):
#     obj_tot +=obj(x_opt[i+1, :], x_z[i]) # u_opt[i], u_opt[i-1], 
# print(obj_tot)
# plt.figure(1)
# plt.plot(u_sol)
# plt.plot(uz_sol)
# for i in range()    


# Average wall-clock time per closed-loop step [s].
# NOTE(review): `start` is taken just before the simulation loop, and the
# zone-metric loops run before this line, so their (small) cost is included.
t_sim = time.time() - start
print(t_sim/int(Nsim/Delta))

# Convert the CasADi logs to NumPy for plotting: x_opt becomes an
# (N+1, Nx) array and u_opt a 1-D array of the applied inputs.
x_opt = np.array(x_opt)
u_opt = np.array(u_opt).flatten()




# np.savetxt('CSTR_results/x_opt_proposed_012_355.txt', x_opt)
# np.savetxt('CSTR_results/u_opt_proposed_012_355.txt', u_opt)
# %%
# dx = x_opt[1:] - x_opt[:-1]
# du = u_opt[1:] - u_opt[:-1]


# For the original CIS: boundary curves of {x : Ha @ x <= Hb} in the
# (C_A, T) plane, for plotting. Each face i, Ha[i,0]*C_A + Ha[i,1]*T = Hb[i],
# is solved for C_A on a grid x2S of temperatures spanning the target zone.
# NOTE(review): arange with a float step may or may not include the endpoint;
# the plots drop the last grid point (x2S[:-1]).
x2S = np.arange(lbx_zn,ubx_zn+0.1,0.1) #.reshape((1,100))
# uS = np.arange(0,1,0.01).reshape((1,100))
x1S = np.zeros((Ha.shape[0], x2S.shape[0]))

for i in range(Ha.shape[0]):
    x1S[i,:] = np.array((-Ha[i, 1]*x2S + Hb[i])/Ha[i, 0]).flatten()
# a: pointwise min over faces 2..10 (upper C_A envelope).
# b: pointwise max over faces 1 and 11..end (lower C_A envelope).
# NOTE(review): these face indices are hard-coded for the row ordering in
# hrepABig.npy. Face 0 is skipped, presumably because it does not bound C_A.
# Recheck the indices whenever the set file changes.
a = x1S[2,:]
for i in range(3, 11):
    a = np.minimum(a, x1S[i,:])
b = x1S[1,:]
for i in range(11,Ha.shape[0]):
    b = np.maximum(b, x1S[i,:])

# For the Shrunk CIS: same construction as for the original set, on a finer
# temperature grid spanning the shrunk zone [lbx_z, ubx_z].
# NOTE(review): np.arange needs lbx_z/ubx_z to be scalars; check they are not
# length-1 arrays coming from the sensitivity step.
x2S_s = np.arange(lbx_z,ubx_z+0.01,0.01) #.reshape((1,100))
# uS = np.arange(0,1,0.01).reshape((1,100))
x1S_s = np.zeros((Ha_s.shape[0], x2S_s.shape[0]))

for i in range(Ha_s.shape[0]):
    x1S_s[i,:] = np.array((-Ha_s[i, 1]*x2S_s + Hb_s[i])/Ha_s[i, 0]).flatten()
# a_s: min over faces 2..11; b_s: max over faces 1 and 14..end. Faces 12-13
# are in neither envelope. The split points depend on gamma (see the table
# below) and must match the loaded hrep*_gamma*.npy files.
a_s = x1S_s[2,:]
for i in range(3, 12):
    a_s = np.minimum(a_s, x1S_s[i,:])
b_s = x1S_s[1,:]
for i in range(14,Ha_s.shape[0]):
    b_s = np.maximum(b_s, x1S_s[i,:])
# Figure 1: quick check of both sets' boundary curves (T on the horizontal
# axis, C_A on the vertical axis).
plt.figure(1)
plt.ylim([0, 1])
# plt.plot(x2S_s, x1S_s[23,:])
plt.plot(x2S, a)
plt.plot(x2S, b)
plt.plot(x2S_s, a_s)
plt.plot(x2S_s, b_s)
# Face-index split used per gamma (end index for a_s, start index for b_s):
# gamma = 3 a_s 9  b_s 10
# gamma = 1 a_s 12  b_s 14
# gamma = 0.5 a_s 11  b_s 12
# gamma = 0.6 a_s 12  b_s 13
# gamma = 0.7 a_s 11  b_s 12
# plotting the lines
# S_up = np.minimum(np.maximum(uS[0], np.ones(xS.shape[0])*uz_lb),  np.ones(xS.shape[0])*uz_ub)
# S_low = np.minimum(np.maximum(uS[1], np.ones(xS.shape[0])*uz_lb),  np.ones(xS.shape[0])*uz_ub)
# S_up[50]=0.3
# S_low = np.maximum(S[1], S[3])
# The upper edge of
# polygon
# up = np.arange(uz_lb, uz_ub+0.1, 0.1)
# xp = np.arange(lbx_z, ubx_z+0.1, 0.1)
# %%

# x_opt0 = np.loadtxt('CSTR_results/x_opt_original_012_355.txt')
# x_opt1 = np.loadtxt('CSTR_results/x_opt_shrunk_CIS_n_09_345.txt')
# Saved trajectory for the (commented-out) "Proposed" comparison plot below.
# Nothing else uses it, and loading it unconditionally made the script crash
# when the file is absent. Re-enable it together with that plot line.
# x_opt2 = np.loadtxt('CSTR_results/x_opt_shrunk_09_345.txt')
csfont = {'fontname': 'Times New Roman'}
font = font_manager.FontProperties(family='Times New Roman',
                                   style='normal', size=13)

# %% Figure 3: state box, nominal target zone, original CIS and the
# closed-loop trajectory in the (C_A, T) plane.
c = np.linspace(lbx[0], ubx[0], num=51)
d = np.linspace(lbx[1], ubx[1], num=51)
fig3 = plt.figure(3)

# State-constraint box X (black).
plt.plot(c, np.ones(c.shape[0])*lbx[1], color="k")
plt.plot(c, np.ones(c.shape[0])*ubx[1], color="k")
plt.plot(np.ones(d.shape[0])*lbx[0], d, color="k")
plt.plot(np.ones(d.shape[0])*ubx[0], d, color="k")

# Nominal target zone on T (dashed blue) and the original CIS (shaded).
plt.plot(c, np.ones(c.shape[0])*lbx_zn, linestyle="--", color="b")
plt.plot(c, np.ones(c.shape[0])*ubx_zn, linestyle="--", color="b")
plt.plot(a[:-1], x2S[:-1], linestyle=":", color='b')
plt.plot(b[:-1], x2S[:-1], linestyle=":", color='b')
plt.fill_betweenx(x2S[:-1], a[:-1], b[:-1], alpha=0.2)

# Shrunk target zone and shrunk CIS (uncomment to overlay):
# plt.plot(c, np.ones(c.shape[0])*lbx_z, linestyle="--", color="r")
# plt.plot(c, np.ones(c.shape[0])*ubx_z, linestyle="--", color="r")
# plt.plot(a_s[:-1], x2S_s[:-1], linestyle=":", color='r')
# plt.plot(b_s[:-1], x2S_s[:-1], linestyle=":", color='r')
# plt.fill_betweenx(x2S_s[:-1], a_s[:-1], b_s[:-1], alpha=0.2)

# Closed-loop trajectory under the true disturbance.
plt.plot(x_opt[:, 0], x_opt[:, 1], linestyle="--", color="r", marker=".",
         label=r"$w \neq 0$")
# plt.plot(x_opt2[:, 0], x_opt2[:, 1], marker=".", linestyle="--", color='r',
#          label="Proposed")

# Raw strings are needed because "\," in a plain literal is an invalid escape
# sequence (SyntaxWarning since Python 3.12). Temperature is in kelvin (K).
plt.xlabel(r'$C_A \, (mol/L)$', **csfont, fontsize=13)
plt.ylabel(r'$T \, (K)$', **csfont, fontsize=13)
plt.legend(frameon=False, edgecolor='black', bbox_to_anchor=(0.9, 0.8),
           ncol=1, prop=font)
fig3.set_rasterized(True)
plt.tight_layout()
# plt.savefig('Figures/CSTR_sets_gamma.eps')
# plt.savefig('Figures/CSTR_Original_proposed_compare_2.eps')



# %% Plot Zone_c
# Trade-off plot of hard-coded results against gamma: zone-tracking cost on
# the left axis, economic cost on the right axis.
# NOTE(review): the arrays appear to be pasted from earlier runs with the
# gamma values in gamma_p; they are not computed by this script. Regenerate
# them if the model or controller changes.
# NOTE(review): MarkerStyle appears unused.
from matplotlib.markers import MarkerStyle
econ_p = np.array([26.41482482, 26.4947851, 26.53490166, 26.57510694, 26.69563801, 27.48351671])
zone_cp = np.array([98.3004, 97.417, 97.3434, 97.3731, 97.51456214, 98.5243])
gamma_p = np.array([0.3, 0.5, 0.6, 0.7, 1, 3])
fig, ax = plt.subplots()
ax.plot(gamma_p, zone_cp, marker = 'p', markerfacecolor ='none', color = 'b', label = "Zone-tracking cost")
ax.set_ylabel("Zone-tracking cost", **csfont, fontsize=14)
# Second y-axis sharing the gamma axis for the economic cost.
ax2 =ax.twinx()
ax2.plot(gamma_p, econ_p, marker = 'o',markerfacecolor ='none', color = 'r', label = "Economic cost")

# marker_style.update(markerfacecolor ='none')
ax2.set_ylabel("Economic cost", **csfont, fontsize=14)
ax.set_xlabel(r'$\gamma$')
fig.tight_layout()
# Figure-level legend so that both axes' lines appear in one legend
# (csfont and font are defined earlier in the script).
fig.legend(frameon=False, edgecolor = 'black', bbox_to_anchor=(0.5, 0.8), ncol=1, prop=font)

# Recorded results for other horizon lengths N (meaning of the numbers not
# stated here; presumably zone-tracking cost — confirm):
## N = 3
# Shrunk CIS 130.828
# Original CIS 80.755

## N = 4
# Shrunk CIS 55.5797
# Original CIS 52.988